{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "name": "Keras Flowers on TPU (playground).ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "environment": {
      "name": "tf22-cpu.2-2.m47",
      "type": "gcloud",
      "uri": "gcr.io/deeplearning-platform-release/tf22-cpu.2-2:m47"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.6"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yd4z24wngWau"
      },
      "source": [
        "Let's train this model on TPU. It's worth it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "89B27-TGiDNB"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9u3d4Z7uQsmp",
        "outputId": "58a37d48-ab7a-4578-b0b9-11356cee67bf"
      },
      "source": [
        "import os, sys, math\n",
        "import numpy as np\n",
        "from matplotlib import pyplot as plt\n",
        "import tensorflow as tf\n",
        "print(\"Tensorflow version \" + tf.__version__)\n",
        "AUTOTUNE = tf.data.AUTOTUNE"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Tensorflow version 2.6.0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UaKGHPjWkcVj"
      },
      "source": [
        "## TPU detection"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tmv6p137kgob"
      },
      "source": [
        "# Pick a distribution strategy: TPU when one is reachable, otherwise GPU/CPU.\n",
        "try: # detect TPUs\n",
        "    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect() # TPU detection\n",
        "    strategy = tf.distribute.TPUStrategy(tpu)\n",
        "except ValueError: # no TPU found: fall back to GPU(s) or CPU\n",
        "    tpu = None # bind the name on this path too, so code that checks `tpu` does not hit a NameError\n",
        "    strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines\n",
        "    #strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU\n",
        "    #strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() # for clusters of multi-GPU machines\n",
        "print(\"Number of accelerators: \", strategy.num_replicas_in_sync)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w9S3uKC_iXY5"
      },
      "source": [
        "## Configuration"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "M3G-2aUBQJ-H"
      },
      "source": [
        "GCS_PATTERN = 'gs://flowers-public/tfrecords-jpeg-192x192-2/*.tfrec'\n",
        "IMAGE_SIZE = [192, 192]\n",
        "\n",
        "if tpu:\n",
        "  BATCH_SIZE = 16*strategy.num_replicas_in_sync  # A TPU has 8 cores so this will be 128\n",
        "else:\n",
        "  BATCH_SIZE = 32  # On Colab/GPU, a higher batch size does not help and sometimes does not fit on the GPU (OOM)\n",
        "\n",
        "VALIDATION_SPLIT = 0.19\n",
        "CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'] # do not change, maps to the labels in the data (folder names)\n",
        "\n",
        "# splitting data files between training and validation\n",
        "filenames = tf.io.gfile.glob(GCS_PATTERN)\n",
        "if not filenames:\n",
        "  # fail with a readable message instead of a ZeroDivisionError in the step computation below\n",
        "  raise FileNotFoundError('No data files match the pattern ' + GCS_PATTERN)\n",
        "split = int(len(filenames) * VALIDATION_SPLIT)\n",
        "training_filenames = filenames[split:]\n",
        "validation_filenames = filenames[:split]\n",
        "print(\"Pattern matches {} data files. Splitting dataset into {} training files and {} validation files\".format(len(filenames), len(training_filenames), len(validation_filenames)))\n",
        "# 3670 is the image count this computation assumes for the whole dataset (presumably the size of\n",
        "# the flowers dataset - confirm); it is integer-divided evenly across the files to estimate how\n",
        "# many images each split holds, then converted to a whole number of batches.\n",
        "images_per_file = 3670 // len(filenames)\n",
        "validation_steps = int(images_per_file * len(validation_filenames)) // BATCH_SIZE\n",
        "steps_per_epoch = int(images_per_file * len(training_filenames)) // BATCH_SIZE\n",
        "print(\"With a batch size of {}, there will be {} batches per training epoch and {} batch(es) per validation run.\".format(BATCH_SIZE, steps_per_epoch, validation_steps))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "tMy0zz6FXnJY"
      },
      "source": [
        "#@title display utilities [RUN ME]\n",
        "\n",
        "def dataset_to_numpy_util(dataset, N):\n",
        "  dataset = dataset.batch(N)\n",
        "  \n",
        "  # In eager mode, iterate in the Datset directly.\n",
        "  for images, labels in dataset:\n",
        "    numpy_images = images.numpy()\n",
        "    numpy_labels = labels.numpy()\n",
        "    break;\n",
        "\n",
        "  return numpy_images, numpy_labels\n",
        "\n",
        "def title_from_label_and_target(label, correct_label):\n",
        "  \"\"\"Build the plot title for one prediction.\n",
        "\n",
        "  Args:\n",
        "    label: predicted class scores (one-hot or probabilities); argmax picks the class.\n",
        "    correct_label: one-hot encoded ground-truth class.\n",
        "\n",
        "  Returns:\n",
        "    A (title, correct) tuple. The title is 'predicted [True]' for a correct\n",
        "    prediction and 'predicted [False, should be expected]' otherwise; `correct`\n",
        "    is the result of comparing the two class numbers.\n",
        "  \"\"\"\n",
        "  label = np.argmax(label, axis=-1)  # one-hot to class number\n",
        "  correct_label = np.argmax(correct_label, axis=-1) # one-hot to class number\n",
        "  correct = (label == correct_label)\n",
        "  if correct:\n",
        "    title = \"{} [{}]\".format(CLASSES[label], str(correct))\n",
        "  else:\n",
        "    title = \"{} [{}, should be {}]\".format(CLASSES[label], str(correct), CLASSES[correct_label])\n",
        "  return title, correct\n",
        "\n",
        "def display_one_flower(image, title, subplot, red=False):\n",
        "    \"\"\"Draw one image in the given subplot slot and return the next slot.\n",
        "\n",
        "    The axes are hidden and the title is drawn in red when `red` is true,\n",
        "    in black otherwise. The return value is `subplot + 1`.\n",
        "    \"\"\"\n",
        "    title_color = 'red' if red else 'black'\n",
        "    plt.subplot(subplot)\n",
        "    plt.axis('off')\n",
        "    plt.imshow(image)\n",
        "    plt.title(title, fontsize=16, color=title_color)\n",
        "    next_subplot = subplot + 1\n",
        "    return next_subplot\n",
        "  \n",
        "def display_9_images_from_dataset(dataset):\n",
        "  \"\"\"Show up to 9 images from an unbatched dataset in a 3x3 grid, titled with their class names.\"\"\"\n",
        "  plt.figure(figsize=(13,13))\n",
        "  images, labels = dataset_to_numpy_util(dataset, 9)\n",
        "  slot = 331  # 3 rows, 3 columns, first cell\n",
        "  # zip with range(9) caps the grid at nine pictures\n",
        "  for index, picture in zip(range(9), images):\n",
        "    class_name = CLASSES[np.argmax(labels[index], axis=-1)]\n",
        "    slot = display_one_flower(picture, class_name, slot)\n",
        "\n",
        "  plt.subplots_adjust(wspace=0.1, hspace=0.1)\n",
        "  plt.show()\n",
        "  \n",
        "def display_9_images_with_predictions(images, predictions, labels):\n",
        "  \"\"\"Show up to 9 images in a 3x3 grid, each titled with its prediction.\n",
        "\n",
        "  Titles come from title_from_label_and_target; wrong predictions are drawn in red.\n",
        "  \"\"\"\n",
        "  plt.figure(figsize=(13,13))\n",
        "  slot = 331  # 3 rows, 3 columns, first cell\n",
        "  # zip with range(9) caps the grid at nine pictures\n",
        "  for index, picture in zip(range(9), images):\n",
        "    caption, is_correct = title_from_label_and_target(predictions[index], labels[index])\n",
        "    slot = display_one_flower(picture, caption, slot, not is_correct)\n",
        "\n",
        "  plt.subplots_adjust(wspace=0.1, hspace=0.1)\n",
        "  plt.show()\n",
        "  \n",
        "def display_training_curves(training, validation, title, subplot):\n",
        "  \"\"\"Plot a training curve and a validation curve in one panel of a shared figure.\n",
        "\n",
        "  Args:\n",
        "    training: sequence of values plotted as the 'train' curve.\n",
        "    validation: sequence of values plotted as the 'valid.' curve.\n",
        "    title: metric name, used for the panel title and the y-axis label.\n",
        "    subplot: matplotlib subplot code (e.g. 211); a code ending in 1 starts a new figure.\n",
        "  \"\"\"\n",
        "  if subplot%10==1: # set up the figure on the first call\n",
        "    # plt.figure rather than plt.subplots: the latter also creates a full-size axes\n",
        "    # that would sit underneath the panels added by plt.subplot below\n",
        "    plt.figure(figsize=(10,10), facecolor='#F0F0F0')\n",
        "  ax = plt.subplot(subplot)\n",
        "  ax.set_facecolor('#F8F8F8')\n",
        "  ax.plot(training)\n",
        "  ax.plot(validation)\n",
        "  ax.set_title('model '+ title)\n",
        "  ax.set_ylabel(title)\n",
        "  ax.set_xlabel('epoch')\n",
        "  ax.legend(['train', 'valid.'])"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kvPXiovhi3ZZ"
      },
      "source": [
        "## Read images and labels from TFRecords"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LtAVr-4CP1rp"
      },
      "source": [
        "def read_tfrecord(example):\n",
        "    \"\"\"Parse one serialized example into an (image, one_hot_class) pair.\n",
        "\n",
        "    The image is decoded from JPEG, scaled to floats in the [0, 1] range and\n",
        "    reshaped to [*IMAGE_SIZE, 3]. The label is the densified 'one_hot_class'\n",
        "    feature reshaped to [5].\n",
        "    \"\"\"\n",
        "    feature_spec = {\n",
        "        \"image\": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring\n",
        "        \"class\": tf.io.FixedLenFeature([], tf.int64),  # shape [] means scalar\n",
        "        \"one_hot_class\": tf.io.VarLenFeature(tf.float32),\n",
        "    }\n",
        "    parsed = tf.io.parse_single_example(example, feature_spec)\n",
        "\n",
        "    pixels = tf.io.decode_jpeg(parsed['image'], channels=3)\n",
        "    pixels = tf.cast(pixels, tf.float32) / 255.0  # floats in [0, 1] range\n",
        "    pixels = tf.reshape(pixels, [*IMAGE_SIZE, 3]) # explicit size will be needed for TPU\n",
        "\n",
        "    label = tf.reshape(tf.sparse.to_dense(parsed['one_hot_class']), [5])\n",
        "    return pixels, label\n",
        "\n",
        "def load_dataset(filenames):\n",
        "  \"\"\"Build an unbatched dataset of (image, one_hot_class) pairs from TFRecord files.\n",
        "\n",
        "  Several files are read in parallel and deterministic ordering is switched\n",
        "  off so that order-altering optimizations are allowed.\n",
        "  \"\"\"\n",
        "  ignore_order = tf.data.Options()\n",
        "  ignore_order.experimental_deterministic = False\n",
        "\n",
        "  records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)\n",
        "  unordered_records = records.with_options(ignore_order)\n",
        "  return unordered_records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xb-b4PRz-V6O"
      },
      "source": [
        "# Visual sanity check: show nine decoded training images titled with their class names.\n",
        "display_9_images_from_dataset(load_dataset(training_filenames))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "22rVDTx8wCqE"
      },
      "source": [
        "## Training and validation datasets"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7wxKyCklR4Gh"
      },
      "source": [
        "def get_batched_dataset(filenames, train=False):\n",
        "  \"\"\"Build the cached, batched and prefetched input pipeline for the given files.\n",
        "\n",
        "  Best practices for Keras: the training dataset is repeated then batched,\n",
        "  the evaluation dataset is not repeated.\n",
        "  Source: Dataset performance guide: https://www.tensorflow.org/guide/performance/datasets\n",
        "  \"\"\"\n",
        "  pipeline = load_dataset(filenames).cache() # This dataset fits in RAM\n",
        "  if train:\n",
        "    pipeline = pipeline.repeat()\n",
        "  # should shuffle too but this dataset was well shuffled on disk already\n",
        "  batches = pipeline.batch(BATCH_SIZE)\n",
        "  # prefetch next batch while training (autotune prefetch buffer size)\n",
        "  return batches.prefetch(AUTOTUNE)\n",
        "\n",
        "# instantiate the datasets\n",
        "training_dataset = get_batched_dataset(training_filenames, train=True)\n",
        "validation_dataset = get_batched_dataset(validation_filenames, train=False)\n",
        "\n",
        "# keep up to 160 validation images and labels as numpy arrays, used later to display predictions\n",
        "some_flowers, some_labels = dataset_to_numpy_util(load_dataset(validation_filenames), 160)"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ALtRUlxhw8Vt"
      },
      "source": [
        "## Model [WORK REQUIRED]\n",
        "1. train the model as it is, with a single convolutional layer\n",
        " * Accuracy 40%... Not great.\n",
        "2. add additional convolutional layers interleaved with max-pooling layers. Try also adding a second dense layer. For example:<br/>\n",
        "`conv 3x3, 16 filters, relu`<br/>\n",
        "`conv 3x3, 30 filters, relu`<br/>\n",
        "`max pool 2x2`<br/>\n",
        "`conv 3x3, 50 filters, relu`<br/>\n",
        "`max pool 2x2`<br/>\n",
        "`conv 3x3, 70 filters, relu`<br/>\n",
        "`flatten`<br/>\n",
        "`dense 5 softmax`<br/>\n",
        " * Accuracy 60%... slightly better. But this model is more than 800K parameters and it overfits dramatically (overfitting = eval loss goes up instead of down).\n",
        "3. Try replacing the Flatten layer by Global average pooling.\n",
        " * Accuracy still 60% but the model is back to a modest 50K parameters, and does not overfit anymore. If you train longer, it can go even higher.\n",
        "4. Try experimenting with 1x1 convolutions too. They typically follow a 3x3 convolution and decrease the filter count. You can also add dropout between the dense layers. For example:<br/>\n",
        "`conv 3x3, 20 filters, relu`<br/>\n",
        "`conv 3x3, 50 filters, relu`<br/>\n",
        "`max pool 2x2`<br/>\n",
        "`conv 3x3, 70 filters, relu`<br/>\n",
        "`conv 1x1, 50 filters, relu`<br/>\n",
        "`max pool 2x2`<br/>\n",
        "`conv 3x3, 100 filters, relu`<br/>\n",
        "`conv 1x1, 70 filters, relu`<br/>\n",
        "`max pool 2x2`<br/>\n",
        "`conv 3x3, 120 filters, relu`<br/>\n",
        "`conv 1x1, 80 filters, relu`<br/>\n",
        "`max pool 2x2`<br/>\n",
        "`global average pooling`<br/>\n",
        "`dense 5 softmax`<br/>\n",
        " * accuracy 70%\n",
        "5. The goal is 80% accuracy! Good luck. (You might want to train for more than 20 epochs to get there. See your training curves to see if it is worth training longer.)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XLJNVGwHUDy1"
      },
      "source": [
        "with strategy.scope(): # this line is all that is needed to run on TPU (or multi-GPU, ...)\n",
        "\n",
        "  model = tf.keras.Sequential([\n",
        "\n",
        "      ###\n",
        "      tf.keras.layers.InputLayer(input_shape=[*IMAGE_SIZE, 3]),\n",
        "      tf.keras.layers.Conv2D(kernel_size=3, filters=20, padding='same', activation='relu'),\n",
        "      #\n",
        "      # YOUR LAYERS HERE\n",
        "      #\n",
        "      # LAYERS TO TRY:\n",
        "      # Conv2D(kernel_size=3, filters=30, padding='same', activation='relu')\n",
        "      # MaxPooling2D(pool_size=2)\n",
        "      # GlobalAveragePooling2D() / Flatten()\n",
        "      # Dense(90, activation='relu')\n",
        "      #\n",
        "      tf.keras.layers.GlobalAveragePooling2D(),\n",
        "      tf.keras.layers.Dense(5, activation='softmax')\n",
        "      ###\n",
        "  ])\n",
        "\n",
        "  # labels are one-hot vectors (read_tfrecord returns 'one_hot_class'), hence categorical_crossentropy\n",
        "  model.compile(\n",
        "    optimizer='adam',\n",
        "    loss='categorical_crossentropy',\n",
        "    metrics=['accuracy'])\n",
        "\n",
        "  model.summary()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dMfenMQcxAAb"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "M-ID7vP5mIKs"
      },
      "source": [
        "EPOCHS = 20\n",
        "\n",
        "# The training dataset repeats indefinitely, so steps_per_epoch defines where an epoch ends.\n",
        "# The validation dataset is not repeated, so no validation_steps is passed.\n",
        "history = model.fit(training_dataset, steps_per_epoch=steps_per_epoch, epochs=EPOCHS,\n",
        "                    validation_data=validation_dataset)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VngeUBIdyJ1T"
      },
      "source": [
        "print(history.history.keys())\n",
        "# two stacked panels: accuracy in the first slot (211), loss in the second (212)\n",
        "display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)\n",
        "display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MKFMWzh0Yxsq"
      },
      "source": [
        "## Predictions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ehlsvY46Hs9z"
      },
      "source": [
        "# randomize the input so that you can execute multiple times to change results\n",
        "# The permutation is sized from the images actually held in some_flowers, so a\n",
        "# shorter sample cannot produce an out-of-range index.\n",
        "permutation = np.random.permutation(len(some_flowers))\n",
        "some_flowers, some_labels = (some_flowers[permutation], some_labels[permutation])\n",
        "\n",
        "predictions = model.predict(some_flowers, batch_size=16)\n",
        "evaluations = model.evaluate(some_flowers, some_labels, batch_size=16)\n",
        "\n",
        "print(np.array(CLASSES)[np.argmax(predictions, axis=-1)].tolist())\n",
        "print('[val_loss, val_acc]', evaluations)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qzCCDL1CZFx6"
      },
      "source": [
        "# show the first nine shuffled validation images with their predicted class (red title = wrong prediction)\n",
        "display_9_images_with_predictions(some_flowers, predictions, some_labels)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SVY1pBg5ydH-"
      },
      "source": [
        "## License"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hleIN5-pcr0N"
      },
      "source": [
        "\n",
        "\n",
        "---\n",
        "\n",
        "\n",
        "author: Martin Gorner<br>\n",
        "twitter: @martin_gorner\n",
        "\n",
        "\n",
        "---\n",
        "\n",
        "\n",
        "Copyright 2021 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "\n",
        "\n",
        "---\n",
        "\n",
        "\n",
        "This is not an official Google product but sample code provided for an educational purpose\n"
      ]
    }
  ]
}